import argparse
import csv
import json
import os
import random
import sys
from ast import literal_eval
from collections import defaultdict
from functools import partial
from pathlib import Path
from typing import Any, Dict, Optional

import numpy as np
import pandas as pd
import spacy
import torch
import torch.nn.functional as F
from tqdm import tqdm

import prompts
# spaCy pipeline loaded once at import time; get_relevant_object() uses it to
# parse user queries.
nlp = spacy.load("en_core_web_lg")


# Directory containing this script. It is put first on sys.path before the
# `model.models` import below (presumably so that project-local package
# resolves regardless of the working directory — confirm).
ROOT_DIR = Path(__file__).parent
sys.path.insert(0, ROOT_DIR.as_posix())

from model.models import (
    forward_blip,
    forward_blip_text,
    forward_clip,
    forward_clip_text,
    forward_egovlpv2,
    forward_egovlpv2_text,
    forward_languagebind,
    forward_languagebind_text,
    init_BLIP,
    init_CLIP,
    init_EgoVLPv2,
    init_languagebind,
)

# Device string used for every model and tensor in this script.
CUDA_DEVICE = "cuda:0"
# Base folders used to build the per-model paths in `config`.
EMBEDDING_DIR = "./embeddings"
VIDEO_DIR = "./data"

# ---------------------------------------------------------------------------
# Command-line interface. Arguments are parsed at import time into `args`.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(
    "Script to perform Composed Video Retrieval on EgoCVR dataset"
)
# Every public attribute of the `prompts` module is offered as "prompts.<name>".
available_prompts = [f'prompts.{x}' for x in prompts.__dict__.keys() if '__' not in x]
parser.add_argument("--gpt_prompt", default='prompts.mllm_CoT_target_video_description', type=str,
                    choices=available_prompts,
                    help='Denotes the base prompt to use alongside GPT4V. Has to be available in prompts.py')
parser.add_argument("--openai_engine", default='gpt-4o-20241120', type=str,
                    choices=["gpt-35-turbo-20220309",
                             "gpt-35-turbo-16k-20230613",
                             "gpt-35-turbo-20230613",
                             "gpt-35-turbo-1106",
                             "gpt-4-20230321",
                             "gpt-4-20230613",
                             "gpt-4-32k-20230321",
                             "gpt-4-32k-20230613",
                             "gpt-4-1106-preview",
                             "gpt-4-0125-preview",
                             "gpt-4-visual-preview",
                             "gpt-4-turbo-20240409",
                             "gpt-4o-20240513",
                             "gpt-4o-20240806",
                             "gpt-4o-20241120",
                             "gpt-4o-mini-20240718",
                             "gpt-4.5-preview-20250227",
                             "gpt-4.1-20250414",
                             "o3-20250416"],
                    help='Openai LLM Engine to use.')

parser.add_argument(
    "--models",
    nargs="*",
    default=["languagebind", "egovlpv2"],
    type=str,
    help="Which models to use for retrieval.",
)
parser.add_argument(
    "--modalities",
    default=["visual", "text"],
    nargs="*",
    type=str,
    help="Query modalities to use for retrieval.",
)
parser.add_argument(
    "--evaluation",
    default="global",
    choices=["local", "global"],
    type=str,
    help="Type of evaluation. Local: within the same video, Global: across all videos",
)
parser.add_argument(
    "--finetuned",
    action="store_true",
    help="Use finetuned CVR model if available (only BLIP).",
)
parser.add_argument(
    "--query_frames", default=15, type=int, help="Number of video query frames."
)
parser.add_argument(
    "--target_frames", default=15, type=int, help="Number of video target frames."
)
parser.add_argument(
    "--text",
    default="tfcvr",
    choices=["instruction", "tfcvr", "gt"],
    type=str,
    help="Type of query text to use for retrieval. instruction: instruction text, tfcvr: modified captions, gt: target clip narration",
)
parser.add_argument(
    "--fusion",
    default="avg",
    choices=["crossattn", "avg"],
    type=str,
    help="Query fusion strategy when using visual-text modality.",
)
parser.add_argument(
    "--min_gallery_size", default=2, type=int, help="Minimum gallery size. default=2"
)
parser.add_argument(
    "--no_precomputed", action="store_true", help="Do not use precomputed embeddings."
)
parser.add_argument(
    "--neighbors",
    default=15,
    type=int,
    help="Number of neighbors to use for the first stage of 2-stage retrieval.",
)

parser.add_argument(
    "--csv_path", type=str, help="Path to the CSV file containing the target clips' details."
)

parser.add_argument(
    "--embeddings_dir", type=str,
    default="/data1/yjgroup/tym/lab_sync_mac/ego/egocvr/EgoCVR/new_embeddings/full_avg_15/",
    help="Directory to load preprocessed embeddings."
)

parser.add_argument(
    "--output_dir", type=str, default="./visualization", help="Directory to save the embeddings."
)

# The three options below are read by init() / main() as `args.data_path`,
# `args.input_csv` and `args.num_frame`. They were never declared, so the
# script failed with AttributeError after the models had been loaded.
# Declaring them as required makes the failure immediate and self-explanatory.
parser.add_argument(
    "--data_path",
    type=str,
    required=True,
    help="Directory containing person_videos_map.json and GPT_narration.json.",
)
parser.add_argument(
    "--input_csv",
    type=str,
    required=True,
    help="Path to the query CSV that is processed row by row "
         "(columns such as person_id, query, object, target_clip, object_attribute).",
)
parser.add_argument(
    "--num_frame",
    type=int,
    required=True,
    help="Length (in frames) of a long-term candidate segment; used to name the "
         "candidate clip and to pick its middle frame.",
)

args = parser.parse_args()

#####################
###### CONFIG #######
#####################
# Per-model settings keyed by model name. Every entry carries the annotation
# CSV, the path of the precomputed embedding CSV and the folder holding the
# video clips; checkpoint paths are present only for the models that list them.
config = {
    "blip": {
        "annotations": f"{ROOT_DIR}/annotation/egocvr/egocvr_annotations_gallery.csv",
        "embedding_path": f"{EMBEDDING_DIR}/EgoCVR_blip-large.csv",
        # Two checkpoints: one of them is copied to "ckpt_path" further down,
        # depending on the --finetuned flag.
        "ckpt_path_finetuned": "./checkpoints/webvid-covr.ckpt",
        "ckpt_path_notfinetuned": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth",
        "video_folder": f"{VIDEO_DIR}/egocvr_clips",
    },
    "egovlpv2": {
        "annotations": f"{ROOT_DIR}/annotation/egocvr/egocvr_annotations_gallery.csv",
        "embedding_path": f"{EMBEDDING_DIR}/EgoCVR_EgoVLPv2.csv",
        "ckpt_path": "./checkpoints/EgoVLPv2.pth",
        # Unlike the other models, this one points at the "egocvr_clips_256" folder.
        "video_folder": f"{VIDEO_DIR}/egocvr_clips_256",
    },
    "clip": {
        "annotations": f"{ROOT_DIR}/annotation/egocvr/egocvr_annotations_gallery.csv",
        "embedding_path": f"{EMBEDDING_DIR}/EgoCVR_ViT-L-14_datacomp_xl_s13b_b90k.csv",
        "video_folder": f"{VIDEO_DIR}/egocvr_clips",
    },
    "languagebind": {
        "annotations": f"{ROOT_DIR}/annotation/egocvr/egocvr_annotations_gallery.csv",
        "embedding_path": f"{EMBEDDING_DIR}/EgoCVR_LanguageBind.csv",
        "video_folder": f"{VIDEO_DIR}/egocvr_clips",
    },
}

# Unpack the parsed CLI options into the module-level settings used below.
modalities = args.modalities
# Validate with parser.error() rather than `assert`: asserts are stripped under
# `python -O`, which would let an invalid configuration through silently.
if len(modalities) > 2:
    parser.error("We implemented only 2 stages")
evaluation = args.evaluation
finetuned = args.finetuned
num_query_frames = args.query_frames
num_target_frames = args.target_frames
fusion = args.fusion
text_variant = args.text
min_gallery_size = args.min_gallery_size
no_precomputed = args.no_precomputed
num_neighbors = args.neighbors

# Recalls
recalls = [1, 2, 3]

# BLIP is the only model with two checkpoints; select one via --finetuned.
if "blip" in args.models:
    config["blip"]["ckpt_path"] = config["blip"][
        "ckpt_path_finetuned" if finetuned else "ckpt_path_notfinetuned"
    ]

# A single query frame is taken from the middle; otherwise frames are sampled.
query_frame_method = "middle" if num_query_frames == 1 else "sample"

# Translate the --text choice into the name used for the query text;
# anything other than "tfcvr"/"gt" maps to "instruction".
text_variant = {
    "tfcvr": "modified_captions",
    "gt": "target_clip_narration",
}.get(text_variant, "instruction")

# If a ".pt" file sits next to the embedding CSV, record its path;
# otherwise store None.
for config_ in config.values():
    _raw_path = config_["embedding_path"].replace(".csv", ".pt")
    config_["embedding_path_raw"] = _raw_path if Path(_raw_path).exists() else None

# main() pairs models with modalities one-to-one via zip(), so the two lists
# must have the same length.
if len(args.models) != len(args.modalities):
    parser.error(
        "--models and --modalities must have the same number of entries "
        f"(got {len(args.models)} and {len(args.modalities)})"
    )


def seed_everything(seed=42):
    """Seed Python's `random`, NumPy and PyTorch with the same value.

    The CUDA generators are seeded as well when CUDA is available.
    """
    # Python, NumPy and PyTorch (CPU) generators, in that order.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return

    # Current CUDA device first, then all CUDA devices.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def get_mid_frame_as_np_array(video_path: str, start_frame: int, num_frames: int):
    """
    Read the frame in the middle of a segment of the video at `video_path`.

    The segment starts at `start_frame` and spans `num_frames` frames; the
    frame at index `start_frame + num_frames // 2` is decoded.

    Returns the frame as returned by `cv2.VideoCapture.read()` (a BGR NumPy
    array), or None when the file is missing, the video reports no frames,
    the midpoint lies outside the video, or the frame cannot be read.
    A warning is printed for every failure case.
    """
    if not os.path.exists(video_path):
        print(f"[WARN] Video file not found: {video_path}")
        return None

    # cv2 is not imported at the top of this file, so import it here; this
    # also keeps the missing-file path free of the OpenCV dependency.
    import cv2

    cap = cv2.VideoCapture(video_path)
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames < 1:
            print(f"[WARN] No frames found in {video_path}.")
            return None

        # Calculate the midpoint of the specified narrow segment
        midpoint_frame = start_frame + num_frames // 2

        # Ensure midpoint_frame is a valid index into the video
        if midpoint_frame < 0:
            print(f"[WARN] Calculated midpoint_frame {midpoint_frame} is negative.")
            return None
        if midpoint_frame >= total_frames:
            print(f"[WARN] Calculated midpoint_frame {midpoint_frame} exceeds total frames {total_frames}.")
            return None

        cap.set(cv2.CAP_PROP_POS_FRAMES, midpoint_frame)
        ret, frame = cap.read()
    finally:
        # Always release the capture handle, even if OpenCV raises.
        cap.release()

    if not ret or frame is None:
        print(f"[WARN] Could not read frame {midpoint_frame} from {video_path}")
        return None

    return frame  # BGR np array

def build_person_videos_map(short_terms_dir):
    """
    Build a mapping person_id -> list of video UIDs from a directory tree.

    Each sub-directory of `short_terms_dir` is treated as a person ID, and each
    sub-directory inside it as one of that person's video UIDs. Plain files are
    ignored at both levels, and persons without any video folder are left out.
    """
    mapping = {}

    for person_id in os.listdir(short_terms_dir):
        person_dir = os.path.join(short_terms_dir, person_id)
        if os.path.isdir(person_dir):
            # Keep only the entries that are folders (the video UIDs).
            uids = [
                entry
                for entry in os.listdir(person_dir)
                if os.path.isdir(os.path.join(person_dir, entry))
            ]
            if uids:
                mapping[person_id] = uids

    return mapping

def process_target_clip(target_clip: str) -> Optional[Dict[str, Any]]:
    """
    Parse a target clip path of the form ".../<video_uid>/<name>_<start>_<end>.mp4".

    Both "/" and "\\" are accepted as path separators, so Windows-style paths
    stored in a CSV are parsed correctly on any platform.

    Returns a dict with:
      - "original_video_uid": name of the clip's parent folder ("" if the
        path has no directory part),
      - "start_frame" / "end_frame": the last two "_"-separated fields of the
        file name, as ints.
    Returns None when the file name has fewer than three "_"-separated fields
    or the frame fields are not integers.
    """
    # Normalise separators so the split below works for Windows and POSIX paths.
    normalized = target_clip.replace("\\", "/")
    parent_dir, _, base_name = normalized.rpartition("/")  # e.g., someUid_123_456.mp4

    splits = base_name.split("_")
    if len(splits) < 3:
        return None

    # The parent folder is the "original_video_uid",
    # e.g. ".../f324ccbc-bef5-4d68-9722-cc99bdaaa660/someUid_123_456.mp4".
    # Take the folder *name*; indexing the string with [-1] would only yield
    # its last character (and fail on an empty directory part).
    video_uid = parent_dir.rstrip("/").rpartition("/")[2]

    # Then parse the frames from the last two elements
    try:
        start_frame = int(splits[-2])
        end_frame = int(splits[-1].replace(".mp4", ""))
    except ValueError:
        return None

    return {
        "original_video_uid": video_uid,
        "start_frame": start_frame,
        "end_frame": end_frame
    }

def init():
    """
    Load the two JSON resources that main() needs from `args.data_path`.

    - person_videos_map.json: loaded if present; otherwise built with
      build_person_videos_map() from a hard-coded directory and written to
      that file.
    - GPT_narration.json: always loaded.

    Returns:
        (person_videos_map, narration_data)
    """
    map_file = os.path.join(args.data_path, "person_videos_map.json")

    if not os.path.exists(map_file):
        # No cached map yet: scan the dataset folder and cache the result.
        short_terms_dir = r"D:\OneDrive - Microsoft\dataset\Ego4D\new_dataset\short_terms"
        person_videos_map = build_person_videos_map(short_terms_dir)
        print(f"Saving person-videos map to: {map_file}")
        with open(map_file, "w") as out_file:
            json.dump(person_videos_map, out_file, indent=4)
    else:
        print(f"Loading person-videos map from: {map_file}")
        with open(map_file, "r") as in_file:
            person_videos_map = json.load(in_file)

    narration_file = os.path.join(args.data_path, "GPT_narration.json")
    print("Loading narration data from: ", narration_file)
    # The narration file is keyed by video UID in main(); entries may hold
    # "narration_pass_1" / "narration_pass_2".
    with open(narration_file, "r", encoding="utf-8") as in_file:
        narration_data = json.load(in_file)

    return person_videos_map, narration_data


def get_relevant_object(sentence: str) -> str:
    """
    Extract the object a sentence is about, using the module-level spaCy
    pipeline `nlp`.

    Strategy:
      1. Return the first non-pronoun direct object ("dobj"), prefixed by its
         "compound" children in sentence order (e.g. "screw driver").
      2. Otherwise return the first noun chunk whose root is a non-pronoun
         direct object or object of a preposition, with articles removed.
      3. Otherwise return "".
    """
    doc = nlp(sentence)

    # Articles to drop from the returned phrase (compared in lowercase).
    articles_to_exclude = {"the", "a", "an"}

    # 1. Look for a direct object that's not a pronoun. Without the pronoun
    #    check, queries like "Where did I put it?" would yield "it".
    for token in doc:
        if token.dep_ != "dobj" or token.pos_ == "PRON":
            continue

        # Gather compound modifiers, ordered by their position in the sentence.
        compounds = sorted(
            (child for child in token.children if child.dep_ == "compound"),
            key=lambda t: t.i,
        )
        words = [
            t.text
            for t in [*compounds, token]
            if t.text.lower() not in articles_to_exclude
        ]
        return " ".join(words)

    # 2. If no direct object found, try noun chunks that aren't pronouns
    for chunk in doc.noun_chunks:
        if chunk.root.dep_ in ("dobj", "pobj") and chunk.root.pos_ != "PRON":
            words = [t.text for t in chunk if t.text.lower() not in articles_to_exclude]
            return " ".join(words)

    # 3. If nothing is found, return empty string
    return ""


def build_attribute_frequency(relevant_objects_attrs):
    """
    Count how often each string attribute value occurs across candidates.

    Each item is indexed positionally; its third element (index 2) must be a
    dict of attribute name -> value. Only values that are `str` are counted.

    Returns a plain nested dict: {attr_name: {attr_value: count}}.
    """
    table = {}

    for item in relevant_objects_attrs:
        attrs = item[2]  # (candidate_uid, ts_frame, object_attributes)
        for name, value in attrs.items():
            if not isinstance(value, str):
                # Non-string attribute values are ignored.
                continue
            counts = table.setdefault(name, {})
            counts[value] = counts.get(value, 0) + 1

    return table

def score_candidate_object(candidate_obj_attrs, frequency_table):
    """
    Score one candidate object against a precomputed frequency table.

    For every attribute/value pair of the candidate that appears in
    `frequency_table` ({attr_name: {attr_value: count}}), the corresponding
    count is added to the score. Pairs that are absent contribute nothing.
    Unhashable attribute values (e.g. lists or dicts) cannot be table keys
    and are skipped instead of raising TypeError.

    Returns the summed count (0 when nothing matches).
    """
    score = 0
    for attr_name, attr_value in candidate_obj_attrs.items():
        value_counts = frequency_table.get(attr_name)
        if not value_counts:
            continue
        try:
            score += value_counts.get(attr_value, 0)
        except TypeError:
            # attr_value is unhashable, so it can never match a table key.
            continue
    return score

def select_best_candidate(relevant_objects_attrs):
    """
    Pick the candidate whose attributes are the most common overall.

    A frequency table is built from all candidates, each candidate
    (candidate_uid, ts_frame, object_attributes) is scored against it, and
    the first candidate with the highest score wins.

    Returns (best_candidate, best_score); (None, -1) when there are no
    candidates.
    """
    frequency_table = build_attribute_frequency(relevant_objects_attrs)

    winner, top_score = None, -1
    for candidate_uid, ts_frame, candidate_obj_attrs in relevant_objects_attrs:
        current = score_candidate_object(candidate_obj_attrs, frequency_table)
        if current <= top_score:
            # Not strictly better: the earlier candidate is kept on ties.
            continue
        winner, top_score = (candidate_uid, ts_frame, candidate_obj_attrs), current

    return winner, top_score

def gather_object_attribute_stats(candidate_objects: list) -> dict:
    """
    Count the occurrences of every attribute value across candidate objects.

    Each element of `candidate_objects` is a 3-tuple whose last element is a
    dict of attribute name -> value. Unhashable values (e.g. lists or dicts)
    cannot be dict keys, so they are counted under their `str()` form instead
    of raising TypeError.

    Returns a plain nested dict: {attr_key: {attr_value: count}}.
    """
    attribute_frequency = {}

    for _, _, obj_attrs in candidate_objects:
        for attr_key, attr_value in obj_attrs.items():
            try:
                hash(attr_value)
            except TypeError:
                # Fall back to a textual key for unhashable values.
                attr_value = str(attr_value)

            value_counts = attribute_frequency.setdefault(attr_key, {})
            value_counts[attr_value] = value_counts.get(attr_value, 0) + 1

    return attribute_frequency


def create_summary_from_stats(stats: dict) -> str:
    """
    Render attribute statistics as text, one attribute per line.

    `stats` maps attribute name -> {value: count}. Each line has the form
    "color: white(5), black(3), red(1)", with values ordered by descending
    count. Returns "" for empty stats.
    """
    def describe(counts):
        # Most frequent value first; ties keep their original order.
        ranked = list(counts.items())
        ranked.sort(key=lambda pair: pair[1], reverse=True)
        return ", ".join(f"{val}({count})" for val, count in ranked)

    return "\n".join(
        f"{attr_key}: {describe(counts)}" for attr_key, counts in stats.items()
    )

def get_attr_summary(relevant_objects_attrs):
    """
    Build, print and return the attribute-frequency summary for a list of
    candidate objects.

    The statistics come from gather_object_attribute_stats() and are rendered
    with create_summary_from_stats().
    """
    candidate_count = len(relevant_objects_attrs)
    print(f"Gathering stats for {candidate_count} relevant objects.")

    # Gather the statistics and turn them into a text summary in one step.
    summary = create_summary_from_stats(
        gather_object_attribute_stats(relevant_objects_attrs)
    )

    print("\n=== Summary (for GPT or other usage) ===")
    print(summary)
    return summary


def target_video_description_generator(
        user_query: str,
        best_candidate_image,  # This should be an image object (or path) you pass to encode_image(...)
        personal_object_summary: str  # This is the attribute-frequency summary (personal usage info)
) -> Optional[str]:
    """
    Ask the LLM selected by `args.openai_engine` for a description of the
    target video and return it.

    The system prompt is the attribute of the `prompts` module named by
    `args.gpt_prompt`. If that name ends in "_text" the text-only completion
    is used, otherwise `best_candidate_image` is sent along with the prompt.

    Returns:
        The value of "Target Video Description" in the JSON response, ""
        when the response is valid JSON without that key, or None when the
        response is not text or not valid JSON.
    """
    text_only_flag = args.gpt_prompt.split("_")[-1] == "text"
    print(f"Using text-only mode: {text_only_flag}")

    # `--gpt_prompt` is restricted by argparse `choices` to "prompts.<name>",
    # so a plain attribute lookup is equivalent to the former eval() call
    # without executing arbitrary expressions.
    sys_prompt = getattr(prompts, args.gpt_prompt.rsplit(".", 1)[-1])

    user_prompt = '''
    <Input>
        {   
            "User Query": %s,
            "Visual Reference": <image_url>
            "Object Attributes Summary": %s.
        }
    ''' % (user_query, personal_object_summary)

    # NOTE(review): `cloudgpt_api` is not imported in the visible part of this
    # file — confirm that it is brought into scope elsewhere.
    completion_kwargs = dict(
        sys_prompt=sys_prompt,
        user_prompt=user_prompt,
        engine=args.openai_engine,
        max_tokens=4096,
        temperature=0,
    )
    if text_only_flag:
        resp = cloudgpt_api.openai_completion_text(**completion_kwargs)
    else:
        # The vision variant additionally receives the image to be encoded.
        resp = cloudgpt_api.openai_completion_vision_CoT(
            image=best_candidate_image,
            **completion_kwargs
        )

    # Guard against a missing / non-text completion (return type of the API
    # is not visible here — presumably a string; verify).
    if not isinstance(resp, str):
        print(f"Error decoding JSON: {resp}")
        return None

    # Strip surrounding whitespace first so the wrapper checks below are not
    # defeated by a leading newline or space.
    resp = resp.strip()

    # Remove <Response> tags if present
    if resp.startswith('<Response>'):
        resp = resp.replace('<Response>', '').replace('</Response>', '').strip()

    # Remove json code fences if present
    if resp.startswith('```json'):
        resp = resp.replace('```json', '').replace('```', '').strip()

    # Extract the target description from the JSON payload.
    try:
        resp_dict = json.loads(resp)
    except json.JSONDecodeError:
        print(f"Error decoding JSON: {resp}")
        return None

    # Valid JSON that is not an object cannot carry the expected key.
    if not isinstance(resp_dict, dict):
        return ""

    return resp_dict.get('Target Video Description', "")


def load_embeddings(path, emb_path=None):
    """
    Load the annotation table and its embeddings onto CUDA_DEVICE.

    Args:
        path: CSV file read with pandas.
        emb_path: optional file loaded with `torch.load`. When it is falsy,
            the embeddings are parsed from the CSV column "clip_embeddings",
            whose cells are Python literals evaluated with `literal_eval`.

    Returns:
        (df, embeddings): the DataFrame and a tensor on CUDA_DEVICE
        (float32 when parsed from the CSV).
    """
    df = pd.read_csv(path)

    if not emb_path:
        # One array per row, stacked along a new leading axis.
        parsed = [np.array(literal_eval(raw)) for raw in df["clip_embeddings"]]
        stacked = np.stack(parsed)
        return df, torch.tensor(stacked, device=CUDA_DEVICE, dtype=torch.float32)

    return df, torch.load(emb_path).to(CUDA_DEVICE)

def compute_recall_at_k(
        query_embeddings,
        candidate_embeddings,
        ground_truth,
        k,
        min_gallery_size,
        modalities,
):
    """
    Compute recall@k and the full ranking for every query.

    Args:
        query_embeddings: mapping modality -> indexable of query embeddings;
            only `modalities[0]` is used.
        candidate_embeddings: mapping candidate id -> embedding tensor.
        ground_truth: list with the target id of each query (one per query).
        k: cut-off rank used for the recall computation.
        min_gallery_size: accepted but not used by this function.
        modalities: list of modality names; the first one selects the queries.

    Returns:
        (recall_at_k, rankings) where `rankings[i]` lists all candidate ids
        for query i, ordered by descending similarity. Recall is 0 when there
        are no queries.
    """
    num_queries = len(ground_truth)
    hits = 0
    rankings = []  # full ranking of candidate ids, one list per query

    for query_idx, target_item in enumerate(ground_truth):
        # Query embedding for the first modality, as a row vector.
        query_vec = query_embeddings[modalities[0]][query_idx].unsqueeze(0)

        # Dot-product similarity between the query and every candidate.
        scored = [
            (cand_id, torch.matmul(query_vec, cand_emb.T).squeeze(0).item())
            for cand_id, cand_emb in candidate_embeddings.items()
        ]

        # Highest similarity first; keep the complete ordering for the caller.
        ranked_ids = [
            cand_id
            for cand_id, _ in sorted(scored, key=lambda pair: pair[1], reverse=True)
        ]
        rankings.append(ranked_ids)

        # Each query has exactly one relevant item: it is a hit when that
        # item is among the top k.
        if target_item in set(ranked_ids[:k]):
            hits += 1

    recall_at_k = hits / num_queries if num_queries > 0 else 0
    return recall_at_k, rankings


def main():
    print(f"Running {args.models} retrieval with {modalities} using {evaluation} evaluation.")
    seed_everything(123)

    tqdm.pandas()
    models = {}
    frame_loaders = {}
    tokenizers = {}
    model_forwards = {}
    text_forwards = {}

    # Initialize models
    if "blip" in args.models:
        model_blip, frame_loader_blip, tokenizer_blip = init_BLIP(
            checkpoint_path=config["blip"]["ckpt_path"],
            query_frame_method=query_frame_method,
            num_query_frames=num_query_frames,
            device=CUDA_DEVICE,
        )
        models["blip"] = model_blip
        frame_loaders["blip"] = frame_loader_blip
        tokenizers["blip"] = tokenizer_blip
        model_forwards["blip"] = forward_blip
        text_forwards["blip"] = forward_blip_text

    if "egovlpv2" in args.models:
        model_egovlpv2, frame_loader_egovlpv2, tokenizer_egovlpv2 = init_EgoVLPv2(
            checkpoint_path=config["egovlpv2"]["ckpt_path"], device=CUDA_DEVICE
        )
        models["egovlpv2"] = model_egovlpv2
        frame_loaders["egovlpv2"] = frame_loader_egovlpv2
        tokenizers["egovlpv2"] = tokenizer_egovlpv2
        model_forwards["egovlpv2"] = forward_egovlpv2
        text_forwards["egovlpv2"] = forward_egovlpv2_text

    if "clip" in args.models:
        model_clip, frame_loader_clip, tokenizer_clip = init_CLIP(
            query_frame_method=query_frame_method,
            num_query_frames=num_query_frames,
            device=CUDA_DEVICE,
        )
        models["clip"] = model_clip
        frame_loaders["clip"] = frame_loader_clip
        tokenizers["clip"] = tokenizer_clip
        model_forwards["clip"] = forward_clip
        text_forwards["clip"] = partial(forward_clip_text, tokenizer=tokenizer_clip)

    if "languagebind" in args.models:
        model_languagebind, frame_loader_languagebind, tokenizer_languagebind = (
            init_languagebind(device=CUDA_DEVICE)
        )
        models["languagebind"] = model_languagebind
        frame_loaders["languagebind"] = frame_loader_languagebind
        tokenizers["languagebind"] = tokenizer_languagebind
        model_forwards["languagebind"] = forward_languagebind
        text_forwards["languagebind"] = forward_languagebind_text

    # Load preprocessed embeddings
    embeddings_dict = {}
    for model in args.models:
        embeddings_dict[model] = load_preprocessed_embeddings(args.embeddings_dir, model)

    # Load CSV data
    df = pd.read_csv(args.csv_path)

    query_embeddings = {}
    for model in args.models:
        query_embeddings[model] = {}
        for modality in args.modalities:
            query_embeddings[model][modality] = defaultdict(list)  # 使用defaultdict按person_id组织

    # Initialize ground_truth dictionary to store the target video IDs by person_id
    ground_truth_by_person = defaultdict(list)

    # Initialize retrieval_data to store the retrieval results
    queries_by_person = defaultdict(list)

    # Load the CSV file
    csv_file_path = args.input_csv
    csv_file_path = os.path.abspath(csv_file_path)
    print(f"Loading CSV file from: {csv_file_path}")
    with open(csv_file_path, 'r') as csv_file:
        reader = csv.DictReader(csv_file)
        rows = list(reader)
        print(f"Loaded {len(rows)} rows from the CSV file.")

    person_videos_map, narration_data = init()
    prompt_version = args.gpt_prompt.split("_")[-1]

    # add a new column to the CSV file
    new_column_name0 = "long_term_similarity_score"
    # Check if the new column already exists
    if new_column_name0 in rows[0]:
        print(f"Column '{new_column_name0}' already exists in the CSV file. Skipping addition.")
    else:
        # Add the new column with empty values
        for row in rows:
            row[new_column_name0] = ""
        print(f"Added new column '{new_column_name0}' to the CSV file.")

    # add a new column to the CSV file and save to the same path
    new_column_name1 = "target_video_description"
    # Check if the new column already exists
    if new_column_name1 in rows[0]:
        print(f"Column '{new_column_name1}' already exists in the CSV file. Skipping addition.")
    else:
        # Add the new column with empty values
        for row in rows:
            row[new_column_name1] = ""
        print(f"Added new column '{new_column_name1}' to the CSV file.")

    # add a new column to the CSV file
    new_column_name2 = "personal_object_summary"
    # Check if the new column already exists
    if new_column_name2 in rows[0]:
        print(f"Column '{new_column_name2}' already exists in the CSV file. Skipping addition.")
    else:
        # Add the new column with empty values
        for row in rows:
            row[new_column_name2] = ""
        print(f"Added new column '{new_column_name2}' to the CSV file.")

    for idx, row in enumerate(rows):
        print(f"\n--- Processing row {idx + 1}/{len(rows)} ---")

        # # print the first 1 rows
        # for row in rows[:1]:X
        #     print(row) #  {'person_id': '39', 'query': 'Where did i put the screw driver?', 'object': 'screw driver', 'target_clip': 'D:\\OneDrive - Microsoft\\dataset\\Ego4D\\new_dataset\\short_terms\\39\\72295d26-19f7-4c6a-874e-85ba8654861e\\cc2d7790-67f7-4e52-9fa9-33121c9431a2_3094_3106.mp4', 'object_attribute': '{"major_category": "tool", "subcategory": "screwdriver", "color": "blue", "shape": "cylindrical", "material": "metal", "texture": "smooth", "size": "medium", "brand": "unknown", "style": "minimalist", "pattern": "solid", "feature": "lightweight", "usage": "professional", "status": "new"}', 'original_video_CLIP': '', 'original_object_attribute': ''}

        # Extract relevant fields from the row
        person_id = row.get("person_id", "")

        # if person_id not in ["30", "137"]:
        #     print(f" [WARN] Person ID {person_id} is not in the specified range. Skipping.")
        #     continue

        target_clip = row.get("target_clip", "")
        target_object = row.get("object", "")  # e.g., "bag"
        target_attr_str = row.get("object_attribute", "")  # JSON describing the short-term object's GPT attributes
        user_query = row.get("query", "")  # e.g., 'Where did i put the screw driver?'

        # The object is in the user_query, should find the object in the query
        object_in_query = get_relevant_object(user_query)
        print(f"User query: {user_query}")  # User query: Where did i put the screw driver?
        print(f"Object in query: {object_in_query}")  # Object in query: screw driver

        # Parse to get short-term video UID
        clip_info = process_target_clip(target_clip)
        # print(f"Clip info: {clip_info}") # Clip info: {'original_video_uid': '579bcf96-f29d-4f02-bb10-ef8b9404e362', 'start_frame': 57277, 'end_frame': 57320}

        # 1) Extract the target clip information
        shortterm_uid = clip_info["original_video_uid"]
        print(f" Short-term clip UID: {shortterm_uid}")

        # 2) Identify the person's other videos that might be "long-term"
        all_vids_for_person = person_videos_map.get(person_id, [])

        # Exclude the short-term video (we want a different video for "long-term")
        candidate_longterm_uids = [v for v in all_vids_for_person if v != shortterm_uid]

        if not candidate_longterm_uids:
            print(f" [WARN] No separate long-term videos for person={person_id} or mapping not found.")
            continue

        relevant_objects_attrs = []
        for candidate_uid in candidate_longterm_uids:
            if candidate_uid not in narration_data:
                # This means we have no GPT_narration data for that video
                print(f" [WARN] No narration data for {candidate_uid}. Skipping.")
                continue

            # Combine pass_1 + pass_2 narrations (if they exist)
            candidate_narrs = []
            video_obj = narration_data[candidate_uid]
            if "narration_pass_1" in video_obj:
                candidate_narrs.extend(video_obj["narration_pass_1"].get("narrations", []))
            if "narration_pass_2" in video_obj:
                candidate_narrs.extend(video_obj["narration_pass_2"].get("narrations", []))

            candidate_video_path = os.path.join(r"E:\datasets\Ego4D\ego4d_data\v2\full_scale", f"{candidate_uid}.mp4")
            if not os.path.exists(candidate_video_path):
                print(f"[WARN] Original video not found: {candidate_video_path}")
                continue


        if len(relevant_objects_attrs) == 0:
            print(f" [WARN] No relevant objects found in candidate long-term videos.")
            continue

        # Identify the best candidate
        best_candidate, best_score = select_best_candidate(relevant_objects_attrs)
        if best_candidate:
            best_candidate_uid, best_ts_frame, best_candidate_attrs = best_candidate
            print("Best Candidate UID:", best_candidate_uid)
            print("Best Candidate Score:", best_score)

            best_clip_filename = f"source_{best_candidate_uid}_{best_ts_frame}_{best_ts_frame + args.num_frame - 1}.mp4"
            print("Clip Filename:", best_clip_filename)
            print("Timestamp Frame:", best_ts_frame)
            target_clip_path = target_clip
            print("Target Clip Path:", target_clip_path)
            print("Object Attributes:", best_candidate_attrs)
            best_candidate_video_path = os.path.join(r"E:\datasets\Ego4D\ego4d_data\v2\full_scale",
                                                     f"{best_candidate_uid}.mp4")

            # Now we get the mid-frame of that full clip
            frame_np = get_mid_frame_as_np_array(best_candidate_video_path, best_ts_frame, num_frames=args.num_frame)
            if frame_np is None:
                continue

        else:
            print("No candidate found.")
            continue

        personal_object_summary = get_attr_summary(relevant_objects_attrs)

        target_description = target_video_description_generator(
            user_query=user_query,
            best_candidate_image=frame_np,
            personal_object_summary=personal_object_summary
        )

        person_id = int(row["person_id"])
        # if person_id not in [137, 30]:
        #     continue
        query_info = {
            "query_text": user_query,
            "target_clip_path": target_clip_path,
            "target_description": target_description,
            "target_object": target_object,
            "object_attributes": best_candidate_attrs
        }
        target_clip_id = Path(row["target_clip"]).stem
        print(f"Processing {target_clip_id} for person_id {person_id} with description: {target_description}")
        query_text = row["query"]  # Extracted query text

        if pd.isna(target_description):  # Check if the description is NaN
            print(f"Warning: Missing description for {target_clip_id}.")
            continue

        # Add the target clip to the ground truth for the person_id
        ground_truth_by_person[person_id].append(target_clip_id)

        queries_by_person[person_id].append(query_info)

        with torch.no_grad():
            # Embed the generated target description as a text-only query with each model
            for modality, model in zip(modalities, args.models):
                query_embedding = model_forwards[model](
                    modality,
                    models[model],
                    tokenizers[model],
                    None,  # Video input not needed for text
                    target_description,  # Use target description as text
                    None,  # Video path not needed for text embeddings
                    frame_loaders[model],
                    fusion,
                    num_query_frames,
                    query_frame_method,
                    use_precomputed=(not no_precomputed),
                )

                # Store query embeddings for later use
                query_embeddings[model][modality][person_id].append(query_embedding)

                # Clear GPU memory
                torch.cuda.empty_cache()

    # Initialize recall_results to store recall for each person and model
    records = []
    recall_results_by_person_and_model = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))

    # Initialize candidate_set_sizes to store the size of candidate sets for each person_id
    candidate_set_sizes = defaultdict(list)

    candidate_durations = defaultdict(list)

    # Compute recall for each person_id
    for person_id, person_ground_truth in ground_truth_by_person.items():
        print(f"\nComputing recall for person_id {person_id} with {len(person_ground_truth)} ground truth items.")
        for model in args.models:
            if person_id not in embeddings_dict[model]:
                print(f"Skipping: person_id {person_id} not found in model {model}")
                continue

            candidate_embeddings = embeddings_dict[model][person_id]
            if not candidate_embeddings:
                print(f"Skipping: No candidate embeddings for person_id {person_id} in model {model}")
                continue

            # Record the candidate set size for the current person_id (one entry per model)
            candidate_set_sizes[person_id].append(len(candidate_embeddings))

            current_person_query_embeddings = {
                modality: query_embeddings[model][modality][person_id]
                for modality in modalities
            }

            # Run retrieval for this person's queries; the per-query results are sliced to each k below
            _, all_retrieval_results = compute_recall_at_k(
                current_person_query_embeddings,
                candidate_embeddings,
                person_ground_truth,
                max(recalls),
                min_gallery_size,
                modalities,
            )

            for k in recalls:
                recall_at_k = 0
                total_queries = len(person_ground_truth)

                for idx, true_id in enumerate(person_ground_truth):
                    if idx >= len(all_retrieval_results):
                        continue

                    top_k_results = all_retrieval_results[idx][:k]
                    if true_id in top_k_results:
                        recall_at_k += 1

                recall_at_k = recall_at_k / total_queries if total_queries > 0 else 0
                recall_results_by_person_and_model[person_id][model][k] = recall_at_k

            # Store retrieval results for each query
            for idx, query_info in enumerate(queries_by_person[person_id]):
                if idx >= len(all_retrieval_results):
                    continue

                retrieval_ids = all_retrieval_results[idx]
                true_id = query_info["target_clip_id"]

                if true_id in retrieval_ids:
                    rank = retrieval_ids.index(true_id) + 1
                else:
                    rank = None

                label = "good" if (rank is not None and rank <= 10) else "bad"

                # Store the retrieval results
                target_clip = Path(query_info["target_clip_path"]).stem

                record = {
                    "model_name": model,
                    "person_id": person_id,
                    "query": query_info["query_text"],
                    "target_video_description": query_info["target_description"],
                    "target_clip": target_clip,
                    "retrieval_results": retrieval_ids,
                    "k": rank,
                    "label": label
                }

                records.append(record)

    # Print recall for each person_id and recall value in a concise format for each model
    print("\nRecall for each person_id and model:")
    for person_id, recall_results in recall_results_by_person_and_model.items():
        print(f"\nPerson {person_id}:")
        for model, model_results in recall_results.items():
            recall_str = ' & '.join([f"Recall@{k}: {np.mean(model_results[k]):.4f}" for k in recalls])
            print(f"  Model {model}: {recall_str}")

    print("Total number of person_ids:", len(ground_truth_by_person))
    print("Total number of queries:", len(records))

    all_sizes = [size for sizes in candidate_set_sizes.values() for size in sizes]
    overall_avg_size = sum(all_sizes) / len(all_sizes) if all_sizes else 0
    print(f"Overall average candidate set size across all person_ids: {overall_avg_size:.2f}")

    all_durations = [duration for durations in candidate_durations.values() for duration in durations]
    if all_durations:
        overall_avg_duration = sum(all_durations) / len(all_durations)
        print(f"\nThe number on total videos: {len(all_durations)}")
        print(f"Total durations (hours): {sum(all_durations) / 3600:.2f} hours")
        print(f"Mini Video: {min(all_durations):.2f} sec, MAX Duration Video: {max(all_durations):.2f} secs")
        print(f"All Video Average length: {overall_avg_duration:.2f}Secs ({overall_avg_duration / 60:.2f} mins)")
    else:
        print("\nError")

    # Calculate and print mean recall for each model and k across all person_ids
    print("\nMean Recall results (mean across all person_ids):")
    for model in args.models:
        print(f"\nResults for model: {model}")
        mean_recall_str = ' & '.join([
            f"mRecall@{k}: {np.mean([recall_results_by_person_and_model[person_id][model][k] for person_id in recall_results_by_person_and_model]):.4f}"
            for k in recalls
        ])
        print(f"  {mean_recall_str}")


# Script entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
